import os
from typing import Dict

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings
from mlagents.plugins import all_trainer_types


logger = get_logger(__name__)


class TrainerFactory:
    def __init__(
        self,
        trainer_config: Dict[str, TrainerSettings],
        output_path: str,
        train_model: bool,
        load_model: bool,
        seed: int,
        param_manager: EnvironmentParameterManager,
        init_path: str = None,
        multi_gpu: bool = False,
    ):
        """
        The TrainerFactory generates the Trainers based on the configuration passed as
        input.
        :param trainer_config: A dictionary from behavior name to TrainerSettings
        :param output_path: The path to the directory where the artifacts generated by
        the trainer will be saved.
        :param train_model: If True, the Trainers will train the model and if False,
        only perform inference.
        :param load_model: If True, the Trainer will load neural networks weights from
        the previous run.
        :param seed: The seed of the Trainers. Dictates how the neural networks will be
        initialized.
        :param param_manager: The EnvironmentParameterManager that will dictate when/if
        the EnvironmentParameters must change.
        :param init_path: Path from which to load model.
        :param multi_gpu: If True, multi-gpu will be used. (currently not available)
        """
        self.trainer_config = trainer_config
        self.output_path = output_path
        self.init_path = init_path
        self.train_model = train_model
        self.load_model = load_model
        self.seed = seed
        self.param_manager = param_manager
        self.multi_gpu = multi_gpu
        self.ghost_controller = GhostController()

    def generate(self, behavior_name: str) -> Trainer:
        trainer_settings = self.trainer_config[behavior_name]
        return TrainerFactory._initialize_trainer(
            trainer_settings,
            behavior_name,
            self.output_path,
            self.train_model,
            self.load_model,
            self.ghost_controller,
            self.seed,
            self.param_manager,
            self.multi_gpu,
        )

    @staticmethod
    def _initialize_trainer(
        trainer_settings: TrainerSettings,
        brain_name: str,
        output_path: str,
        train_model: bool,
        load_model: bool,
        ghost_controller: GhostController,
        seed: int,
        param_manager: EnvironmentParameterManager,
        multi_gpu: bool = False,
    ) -> Trainer:
        """
        Initializes a trainer given a provided trainer configuration and brain parameters, as well as
        some general training session options.

        :param trainer_settings: Original trainer configuration loaded from YAML
        :param brain_name: Name of the brain to be associated with trainer
        :param output_path: Path to save the model and summary statistics
        :param keep_checkpoints: How many model checkpoints to keep
        :param train_model: Whether to train the model (vs. run inference)
        :param load_model: Whether to load the model or randomly initialize
        :param ghost_controller: The object that coordinates ghost trainers
        :param seed: The random seed to use
        :param param_manager: EnvironmentParameterManager, used to determine a reward buffer length for PPOTrainer
        :return:
        """
        trainer_artifact_path = os.path.join(output_path, brain_name)

        min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)

        trainer: Trainer = None  # type: ignore  # will be set to one of these, or raise

        try:
            trainer_type = all_trainer_types[trainer_settings.trainer_type]
            trainer = trainer_type(
                brain_name,
                min_lesson_length,
                trainer_settings,
                train_model,
                load_model,
                seed,
                trainer_artifact_path,
            )

        except KeyError:
            raise TrainerConfigError(
                f"The trainer config contains an unknown trainer type "
                f"{trainer_settings.trainer_type} for brain {brain_name}"
            )

        if trainer_settings.self_play is not None:
            trainer = GhostTrainer(
                trainer,
                brain_name,
                ghost_controller,
                min_lesson_length,
                trainer_settings,
                train_model,
                trainer_artifact_path,
            )
        return trainer
